Adding a New Policy
Overview
This guide explains how to add a new policy to LW-BenchHub. All policies inherit from the BasePolicy class, which provides a standardized interface for policy deployment and evaluation.
Policy Framework Architecture
Base Policy Class
The BasePolicy class (policy/base.py) defines the common interface that all policies must implement:
from abc import ABC, abstractmethod
from typing import Any, Dict

class BasePolicy(ABC):
"""Base Policy Class - All policies should inherit from this class"""
def __init__(self, usr_args: Dict[str, Any]):
"""Initialize policy with user arguments"""
@abstractmethod
def get_model(self, usr_args: Dict[str, Any]) -> Any:
"""Load and initialize the policy model"""
@abstractmethod
def get_action(self) -> Any:
"""Get action from the policy model"""
@abstractmethod
def eval(self, task_env, observation, usr_args, video_writer) -> bool:
"""Evaluate policy on a task"""
@abstractmethod
def reset_model(self) -> None:
"""Reset model state between episodes"""
Provided Utility Methods
The base class provides several utility methods that you can use:
- encode_obs(observation): Preprocess observation data (handles tensor conversion and reshaping)
- add_video_frame(video_writer, obs, camera_key): Add frames to the video recording
- step_environment(task_env, action, usr_args): Execute an environment step with action mapping support
- get_instruction(): Get the task instruction from user arguments
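Taken together, these utilities cover most of the boilerplate in an evaluation loop. The following is a minimal, illustrative sketch of how they typically compose inside eval (the complete implementation appears in Step 2 below):
def eval(self, task_env, observation, usr_args, video_writer) -> bool:
    for _ in range(usr_args['time_out_limit']):
        # Preprocess the raw observation (tensor conversion, reshaping)
        self.observation_window = self.encode_obs(observation)
        # Query the policy model for the next action
        action = self.get_action()
        # Step the environment (applies action mapping if configured)
        observation, terminated = self.step_environment(task_env, action, usr_args)
        # Record the configured camera views
        self.add_video_frame(video_writer, observation, usr_args['record_camera'])
        if terminated:
            return True
    return False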
Step-by-Step Guide: Adding a New Policy
Step 1: Create Policy Directory
Create a new directory under policy/ for your policy:
policy/
├── base.py
├── GR00T/
├── PI/
└── YourPolicy/ # Your new policy
├── your_policy.py
└── deploy_policy.yml
Step 2: Implement Your Policy Class
Create your_policy.py that inherits from BasePolicy:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Your Policy Implementation
"""
import torch
import numpy as np
import sys
import os
from typing import Dict, Any
from policy.base import BasePolicy
# Add current directory to path if needed
current_file_path = os.path.abspath(__file__)
parent_directory = os.path.dirname(current_file_path)
sys.path.append(parent_directory)
# Import your policy library
try:
from your_policy_library import YourPolicyModel
except ImportError as e:
print(f"Your policy library not found: {e}")
class YourPolicy(BasePolicy):
"""Your Policy Implementation"""
def __init__(self, usr_args: Dict[str, Any]):
"""Initialize your policy"""
super().__init__(usr_args)
def get_model(self, usr_args: Dict[str, Any]):
"""
Load and initialize your policy model
This method is called during initialization.
Load your model checkpoint and set up any required configurations.
"""
# Extract configuration from usr_args
checkpoint = usr_args.get("checkpoint") or usr_args.get("ckpt_setting")
observation_config = usr_args.get("observation_config", {})
# Set up device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Store configurations
self.observation_config = observation_config
# Load your model
self.model = YourPolicyModel(
checkpoint=checkpoint,
device=self.device
)
print("Successfully loaded your policy model!")
def encode_obs(self, observation: Dict[str, Any]) -> Dict[str, Any]:
"""
Encode observation into format expected by your model
Override this method if you need custom observation preprocessing.
"""
# Use parent's encode_obs for standard preprocessing
observation = super().encode_obs(observation, transpose=True, keep_dim_env=False)
# Build observation window for your model
observation = self._build_observation_window(observation)
return observation
def _build_observation_window(self, obs: Dict[str, Any]) -> Dict[str, Any]:
"""
Build observation window with custom mapping
Map observation keys to your model's expected input format.
"""
custom_mapping = self.observation_config.get("custom_mapping", {})
obs_window = {"instruction": self.instruction}
# Map observations according to custom_mapping
for key, mapping in custom_mapping.items():
if isinstance(mapping, dict):
# Handle nested mapping
obs_window[key] = {k: obs[v] for k, v in mapping.items()}
else:
# Direct mapping
obs_window[key] = obs[mapping]
return obs_window
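    # Example: with the custom_mapping shown in Step 3's deploy_policy.yml
    # ({"images/front": "global_camera", "images/wrist": "hand_camera",
    #   "state": "joint_pos", ...}), this method returns:
    #   {"instruction": self.instruction,
    #    "images/front": obs["global_camera"],
    #    "images/wrist": obs["hand_camera"],
    #    "state": obs["joint_pos"], ...}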
def get_action(self) -> torch.Tensor:
"""
Get action from your policy model
Returns:
Action tensor to execute in the environment
"""
# Get action from your model
action = self.model.predict(self.observation_window)
# Post-process if needed
if isinstance(action, np.ndarray):
action = torch.from_numpy(action).float().to(self.device)
return action
def eval(self, task_env: Any, observation: Dict[str, Any],
usr_args: Dict[str, Any], video_writer: Any) -> bool:
"""
Evaluate your policy on a task
Args:
task_env: Task environment (RemoteEnv instance)
observation: Initial observation
usr_args: User arguments (contains time_out_limit, record_camera, etc.)
video_writer: Video writer for recording
Returns:
Whether the task was completed successfully
"""
terminated = False
# Main evaluation loop
for step in range(usr_args['time_out_limit']):
# Encode observation
self.observation_window = self.encode_obs(observation)
# Get action from policy
actions = self.get_action()
# Execute action(s)
# If your policy outputs action chunks, iterate through them
if actions.dim() > 1: # Multiple actions (chunk)
for i in range(actions.shape[0]):
observation, terminated = self.step_environment(
task_env, actions[i], usr_args
)
self.add_video_frame(
video_writer, observation, usr_args['record_camera']
)
if terminated:
return terminated
else: # Single action
observation, terminated = self.step_environment(
task_env, actions, usr_args
)
self.add_video_frame(
video_writer, observation, usr_args['record_camera']
)
if terminated:
return terminated
return terminated
def reset_model(self) -> None:
"""
Reset model state between episodes
Clear any internal state or observation buffers.
"""
self.observation_window = None
# Reset any other model-specific state
if hasattr(self.model, 'reset'):
self.model.reset()
print("Model state reset successfully")
Step 3: Create Configuration File
Create deploy_policy.yml to define policy parameters. The configuration includes both policy settings and the environment configuration:
# Policy Configuration
policy_name: 'YourPolicy' # Policy class name (must match your policy class)
seed: 0
# Model Configuration
ckpt_setting: 'path/to/checkpoint' # Path to trained policy checkpoint
instruction: "Your task instruction" # Task instruction/prompt
# Policy-specific parameters
your_custom_param1: value1
your_custom_param2: value2
# Observation Configuration
observation_config:
custom_mapping:
# Map observation keys from LW-BenchHub to your model's expected format
images/front: global_camera # Camera observations
images/wrist: hand_camera
state: joint_pos # State observations
action: joint_target_pos
# Evaluation Configuration
record_camera: ["global_camera", "hand_camera"]
time_out_limit: 500
height: 480 # Camera image height
width: 480 # Camera image width
# Environment Configuration (sent to server via attach())
env_cfg:
task: YourTask # Task name
robot: LeRobot-AbsJointGripper-RL # Robot type
layout: robocasakitchen # Scene layout
scene_backend: robocasa # Scene backend
task_backend: robocasa # Task backend
device: cuda:0 # Device for simulation
num_envs: 1 # Number of parallel environments
enable_cameras: true # Enable camera observations
usd_simplify: false # USD simplification
video: false # Record video in environment
seed: 42 # Random seed
for_rl: false # RL mode (false for policy evaluation)
variant: Visual # Observation variant (Visual/State)
concatenate_terms: false # Concatenate observation terms
distributed: false # Multi-GPU training mode
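The env_cfg section is forwarded to the simulation server when the evaluation client attaches. The framework's deploy script handles this for you, but as a rough sketch (assuming the file is loaded with PyYAML; the RemoteEnv import path depends on your LW-BenchHub installation):
import yaml

with open("policy/YourPolicy/deploy_policy.yml") as f:
    usr_args = yaml.safe_load(f)

# RemoteEnv is provided by LW-BenchHub (import omitted here)
env = RemoteEnv.make(address=('127.0.0.1', 50000), authkey=b'lightwheel')
env.attach(usr_args['env_cfg'])   # send the env_cfg section to the server
obs, _ = env.reset()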
Step 4: Register Your Policy
Add your policy to policy/__init__.py:
from policy.YourPolicy.your_policy import YourPolicy
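After registration, the policy can be imported by name. A quick, hypothetical smoke test of the wiring (note that get_model() runs during __init__, so the checkpoint path in the config must be valid):
import yaml
from policy import YourPolicy

with open("policy/YourPolicy/deploy_policy.yml") as f:
    usr_args = yaml.safe_load(f)

policy = YourPolicy(usr_args)   # loads the model via get_model()
policy.reset_model()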
Common Patterns and Best Practices
1. Observation Preprocessing
Different policies may require different observation formats:
def encode_obs(self, observation: Dict[str, Any]) -> Dict[str, Any]:
# Option 1: Standard preprocessing (used by PI)
observation = super().encode_obs(observation, transpose=True, keep_dim_env=False)
# Option 2: Keep environment dimension (used by GR00T)
observation = super().encode_obs(observation, transpose=False, keep_dim_env=True)
# Then build your custom observation window
observation = self._build_observation_window(observation)
return observation
2. Action Chunking
If your policy predicts multiple future actions:
def eval(self, task_env, observation, usr_args, video_writer):
for _ in range(usr_args['time_out_limit']):
observation = self.encode_obs(observation)
actions = self.get_action() # Shape: (chunk_size, action_dim)
# Execute all actions in the chunk
for i in range(actions.shape[0]):
observation, terminated = self.step_environment(
task_env, actions[i], usr_args
)
if terminated:
return terminated
return terminated
3. Joint Mapping
If your policy's action space differs from the robot's:
def step_environment(self, task_env, action, usr_args):
# Apply joint mapping if provided
if 'joint_mapping' in usr_args:
action = action[usr_args['joint_mapping']]
# Convert to tensor if needed
if isinstance(action, np.ndarray):
action = torch.from_numpy(action).float().cuda()
obs, _, terminated, _, _ = task_env.step(action.unsqueeze(0))
return obs, terminated
4. Error Handling
Always include proper import error handling:
try:
from your_policy_library import YourModel
except ImportError as e:
print(f"Policy library not found. Please install it first: {e}")
print("Installation: pip install your-policy-library")
Troubleshooting
Common Issues
1. Import Errors
# Add proper path management
import sys
import os
current_file_path = os.path.abspath(__file__)
parent_directory = os.path.dirname(current_file_path)
sys.path.append(parent_directory)
2. Observation Shape Mismatch
# Debug observation shapes
def encode_obs(self, observation):
print(f"Input obs keys: {observation['policy'].keys()}")
for k, v in observation['policy'].items():
print(f"{k}: {v.shape if torch.is_tensor(v) else type(v)}")
observation = super().encode_obs(observation)
# ... rest of encoding
3. Action Space Mismatch
# Verify action dimensions
def get_action(self):
action = self.model.predict(self.observation_window)
expected_dim = 7 # Your robot's action dimension
assert action.shape[-1] == expected_dim, \
f"Action dim mismatch: got {action.shape[-1]}, expected {expected_dim}"
return action
4. Environment Not Attached
# Make sure to call attach() before using the environment
env = RemoteEnv.make(address=('127.0.0.1', 50000), authkey=b'lightwheel')
env.attach(env_cfg) # Don't forget this!
obs, _ = env.reset()
Summary
To add a new policy to LW-BenchHub:
- Create a new directory under policy/
- Implement your policy class inheriting from BasePolicy
- Override the four abstract methods: get_model, get_action, eval, reset_model
- Create a configuration YAML file with policy parameters and an env_cfg section
- Register your policy in policy/__init__.py
- Test your implementation before deployment
The framework provides flexibility through:
- Standardized interfaces via BasePolicy
- Utility methods for common operations
- Flexible observation preprocessing
- Action chunking support
- Custom configuration management
- Attach/Detach architecture for dynamic environment configuration
Refer to GR00T and PI implementations for real-world examples of different policy architectures and patterns.